{
  "cells": [
    {
      "cell_type": "markdown",
      "id": "3b53ae0a-4cb5-44fd-bcde-5ce7abce79e9",
      "metadata": {
        "id": "3b53ae0a-4cb5-44fd-bcde-5ce7abce79e9"
      },
      "source": [
        "# Introduction\n",
        "\n",
        "The goal of this tutorial is to demonstrate the basic steps required to setup and train a simple single-channel speech enhancement model in NeMo using online augmentation with noise and room impulse responce (RIR). Online augmentation is performed using a dataloader based on Lhotse speech data processing toolkit [1].\n",
        "\n",
        "\n",
        "This notebook covers the following steps:\n",
        "\n",
        "* Download speech, noise and RIR data\n",
        "* Prepare Lhotse manifests for speech, noise and RIR data\n",
        "* Prepare fixed validation set by mixing speech, noise and RIR data\n",
        "* Configure and train a simple single-output model\n",
        "\n",
        "Note that this tutorial is only for demonstration purposes.\n",
        "To achieve best performance for a particular use case, carefully prepared data and more advanced models should be used.\n",
        "\n",
        "*Disclaimer:*\n",
        "User is responsible for checking the content of datasets and the applicable licenses and determining if suitable for the intended use."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "3d603b38-b63c-42a6-92b3-575925149bde",
      "metadata": {
        "id": "3d603b38-b63c-42a6-92b3-575925149bde"
      },
      "outputs": [],
      "source": [
        "\"\"\"\n",
        "You can run either this notebook locally (if you have all the dependencies and a GPU) or on Google Colab.\n",
        "\n",
        "Instructions for setting up Colab are as follows:\n",
        "1. Open a new Python 3 notebook.\n",
        "2. Import this notebook from GitHub (File -> Upload Notebook -> \"GITHUB\" tab -> copy/paste GitHub URL)\n",
        "3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select \"GPU\" for hardware accelerator)\n",
        "4. Run this cell to set up dependencies.\n",
        "5. Restart the runtime (Runtime -> Restart Runtime) for any upgraded packages to take effect\n",
        "\"\"\"\n",
        "\n",
        "GIT_USER = 'NVIDIA'\n",
        "BRANCH = 'main'\n",
        "\n",
        "if 'google.colab' in str(get_ipython()):\n",
        "\n",
        "    # Install dependencies\n",
        "    !pip install wget\n",
        "    !apt-get install sox libsndfile1 ffmpeg\n",
        "    !pip install text-unidecode\n",
        "    !pip install matplotlib>=3.3.2\n",
        "\n",
        "    ## Install NeMo\n",
        "    !python -m pip install git+https://github.com/{GIT_USER}/NeMo.git@{BRANCH}#egg=nemo_toolkit[all]\n",
        "\n",
        "    ## Install TorchAudio\n",
        "    !pip install torchaudio>=0.13.0 -f https://download.pytorch.org/whl/torch_stable.html"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "bdac6ac8-a21f-4ea8-8484-41a6d187a2fe",
      "metadata": {
        "id": "bdac6ac8-a21f-4ea8-8484-41a6d187a2fe"
      },
      "source": [
        "The following cell will take care of the necessary imports and prepare utility functions used throughout the notebook."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "baca13c6-b0ed-429b-93b6-77249fdf4710",
      "metadata": {
        "id": "baca13c6-b0ed-429b-93b6-77249fdf4710"
      },
      "outputs": [],
      "source": [
        "import glob\n",
        "import librosa\n",
        "import os\n",
        "import torch\n",
        "import tqdm\n",
        "from itertools import islice\n",
        "\n",
        "import IPython.display as ipd\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import lightning.pytorch as pl\n",
        "import soundfile as sf\n",
        "from pathlib import Path\n",
        "from omegaconf import OmegaConf, open_dict\n",
        "from sklearn.model_selection import train_test_split\n",
        "from torchmetrics.functional.audio import signal_distortion_ratio, scale_invariant_signal_distortion_ratio\n",
        "from lhotse import CutSet, RecordingSet, Recording, MonoCut\n",
        "from lhotse.recipes import (\n",
        "    download_rir_noise,\n",
        "    prepare_rir_noise,\n",
        "    download_librispeech,\n",
        "    prepare_librispeech\n",
        ")\n",
        "\n",
        "from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config\n",
        "from nemo.collections.audio.data.audio_to_audio_lhotse import LhotseAudioToTargetDataset"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "fc0fafa4-bc65-4066-8e75-d740e2d15259",
      "metadata": {
        "id": "fc0fafa4-bc65-4066-8e75-d740e2d15259"
      },
      "source": [
        "Utility functions for displaying signals and metrics"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "49720c06-3b4d-45b6-a054-73800c3bddea",
      "metadata": {
        "id": "49720c06-3b4d-45b6-a054-73800c3bddea"
      },
      "outputs": [],
      "source": [
        "def show_signal(signal: np.ndarray, sample_rate: int = 16000, tag: str = 'Signal'):\n",
        "    \"\"\"Show the time-domain signal and its spectrogram.\n",
        "    \"\"\"\n",
        "    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 2.5))\n",
        "\n",
        "    # show waveform\n",
        "    t = np.arange(0, len(signal)) / sample_rate\n",
        "\n",
        "    ax[0].plot(t, signal)\n",
        "    ax[0].set_xlim(0, t.max())\n",
        "    ax[0].grid()\n",
        "    ax[0].set_xlabel('time / s')\n",
        "    ax[0].set_ylabel('amplitude')\n",
        "    ax[0].set_title(tag)\n",
        "\n",
        "    n_fft = 1024\n",
        "    hop_length = 256\n",
        "\n",
        "    D = librosa.amplitude_to_db(np.abs(librosa.stft(signal, n_fft=n_fft, hop_length=hop_length)), ref=np.max)\n",
        "    img = librosa.display.specshow(D, y_axis='linear', x_axis='time', sr=sample_rate, n_fft=n_fft, hop_length=hop_length, ax=ax[1])\n",
        "    ax[1].set_title(tag)\n",
        "\n",
        "    plt.tight_layout()\n",
        "    plt.colorbar(img, format=\"%+2.f dB\", ax=ax)\n",
        "\n",
        "def show_metrics(signal: np.ndarray, reference: np.ndarray, sample_rate: int = 16000, tag: str = 'Signal'):\n",
        "    \"\"\"Show metrics for the time-domain signal and the reference signal.\n",
        "    \"\"\"\n",
        "    sdr = signal_distortion_ratio(preds=torch.tensor(signal), target=torch.tensor(reference))\n",
        "    sisdr = scale_invariant_signal_distortion_ratio(preds=torch.tensor(signal), target=torch.tensor(reference))\n",
        "    print(tag)\n",
        "    print('\\tsdr:  ', sdr.item())\n",
        "    print('\\tsisdr:', sisdr.item())"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "653eb22f-b09e-4421-8028-05d123fc47a5",
      "metadata": {
        "id": "653eb22f-b09e-4421-8028-05d123fc47a5"
      },
      "source": [
        "### Data preparation"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "42deb721-f734-43ea-bfef-931733d6379b",
      "metadata": {
        "id": "42deb721-f734-43ea-bfef-931733d6379b"
      },
      "source": [
        "In this notebook, it is assumed that all audio will be resampled to 16kHz and the data and configuration will be stored under `root_dir` as defined below."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "577d4be2-10f8-4050-8f40-b2566ef9dfef",
      "metadata": {
        "id": "577d4be2-10f8-4050-8f40-b2566ef9dfef"
      },
      "outputs": [],
      "source": [
        "# sample rate used throughout the notebook\n",
        "sample_rate = 16000\n",
        "\n",
        "# root directory for data preparation, configurations, etc\n",
        "root_dir = Path('./')\n",
        "\n",
        "# data directory\n",
        "data_dir = root_dir / 'data'\n",
        "data_dir.mkdir(exist_ok=True)\n",
        "\n",
        "# scripts directory\n",
        "scripts_dir = root_dir / 'scripts'\n",
        "scripts_dir.mkdir(exist_ok=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "9f6dc925-04b1-4118-8092-c3107a734f4f",
      "metadata": {
        "id": "9f6dc925-04b1-4118-8092-c3107a734f4f"
      },
      "source": [
        "Create dictionary with paths for all of the manifests files which will be stored under `data_dir`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "da232926-a529-4409-b983-8a921c578c91",
      "metadata": {
        "id": "da232926-a529-4409-b983-8a921c578c91"
      },
      "outputs": [],
      "source": [
        "dataset_manifest = {\n",
        "    'speech_train': data_dir / 'libri_cuts_train.jsonl.gz',\n",
        "    'speech_val': data_dir / 'libri_cuts_val.jsonl.gz',\n",
        "    'noise_train': data_dir / 'demand_cuts_train.jsonl.gz',\n",
        "    'noise_val': data_dir / 'demand_cuts_val.jsonl.gz',\n",
        "    'rir_train': data_dir / 'rir_recordings_train.jsonl.gz',\n",
        "    'rir_val': data_dir / 'rir_recordings_val.jsonl.gz',\n",
        "    'noisy_val': data_dir / 'noisy_cuts_val.jsonl.gz'\n",
        "}"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "d0e19997-b151-47f5-b19b-97d8c7e65e2e",
      "metadata": {
        "id": "d0e19997-b151-47f5-b19b-97d8c7e65e2e"
      },
      "source": [
        "In this tutorial, a subset of LibriSpeech dataset [2] will be downloaded and used as the speech material.\n",
        "\n",
        "To use a dataset with the Lhotse dataloader, we need to create manifest files from Lhotse cuts (refer to [3] for the details). In this cell, we first download and prepare the LibriSpeech dataset in a Lhotse format and then save it as manifest files for training and validation sets. Note that the target recording in the speech enhancement task is the original (unchanged) clean speech signal, which is defined under the custom field \"target_recording\" in the cuts."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "d4156a68-1906-4a4c-9ecf-cab5d79eb13b",
      "metadata": {
        "id": "d4156a68-1906-4a4c-9ecf-cab5d79eb13b"
      },
      "outputs": [],
      "source": [
        "libri_variant = 'mini_librispeech'\n",
        "speech_dir = data_dir / 'speech'\n",
        "\n",
        "libri_root = download_librispeech(speech_dir, dataset_parts=libri_variant)\n",
        "\n",
        "# Use script from Lhotse to prepate Librispeech dataset to Lhotse format\n",
        "libri = prepare_librispeech(\n",
        "    libri_root, dataset_parts=libri_variant,\n",
        ")\n",
        "cuts_train = CutSet.from_manifests(**libri[\"train-clean-5\"]).trim_to_supervisions()\n",
        "cuts_val = CutSet.from_manifests(**libri[\"dev-clean-2\"]).trim_to_supervisions()\n",
        "\n",
        "# Save the manifest with a custom \"target_recording\"\n",
        "with CutSet.open_writer(dataset_manifest['speech_train']) as writer:\n",
        "    for cut in cuts_train:\n",
        "        cut.target_recording = cut.recording\n",
        "        writer.write(cut)\n",
        "\n",
        "with CutSet.open_writer(dataset_manifest['speech_val']) as writer:\n",
        "    for cut in cuts_val:\n",
        "        cut.target_recording = cut.recording\n",
        "        writer.write(cut)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "6f5a0eb7-b913-4323-a8ce-1da77521730b",
      "metadata": {
        "id": "6f5a0eb7-b913-4323-a8ce-1da77521730b"
      },
      "source": [
        "During the training phase, noise data will be used for online augmentation by mixing it with the downloaded speech on-the-fly. During the validation and test phases, the noise will be used to create fixed sets.\n",
        "\n",
        "The following cell will download and prepare the noise data using a subset of the DEMAND dataset [4]."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "ba4dcfe4-5614-438d-9c2b-73c0fbf3bc97",
      "metadata": {
        "id": "ba4dcfe4-5614-438d-9c2b-73c0fbf3bc97"
      },
      "outputs": [],
      "source": [
        "noise_dir = data_dir / 'noise'\n",
        "noise_data_set = 'STRAFFIC,PSTATION'\n",
        "\n",
        "# Copy script\n",
        "get_demand_script = os.path.join(scripts_dir, 'get_demand_data.py')\n",
        "if not os.path.exists(get_demand_script):\n",
        "    !wget -P $scripts_dir https://raw.githubusercontent.com/{GIT_USER}/NeMo/{BRANCH}/scripts/dataset_processing/get_demand_data.py\n",
        "\n",
        "if not noise_dir.is_dir():\n",
        "    noise_dir.mkdir(exist_ok=True)\n",
        "    !python {get_demand_script} --data_root={noise_dir} --data_sets={noise_data_set}\n",
        "else:\n",
        "    print('Noise directory already exists in:', noise_dir)\n",
        "\n",
        "noise_dir = data_dir / 'noise'\n",
        "demand_recordings = RecordingSet.from_dir(noise_dir, pattern='*.wav')\n",
        "\n",
        "demand_cuts = CutSet.from_manifests(recordings=demand_recordings)\n",
        "shuffled_demand_cuts = demand_cuts.shuffle()\n",
        "\n",
        "demand_cuts_train = shuffled_demand_cuts.subset(last=len(shuffled_demand_cuts)-3)\n",
        "demand_cuts_val = shuffled_demand_cuts.subset(first=3)\n",
        "\n",
        "demand_cuts_train.to_file(dataset_manifest['noise_train'])\n",
        "demand_cuts_val.to_file(dataset_manifest['noise_val'])"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "f9429dd5-fedc-486d-92f1-5a0255d600e6",
      "metadata": {
        "id": "f9429dd5-fedc-486d-92f1-5a0255d600e6"
      },
      "source": [
        "The following cell will download and prepare a simulated subset from room impulse responses dataset, described in the following paper [5]."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "3e83f5d2-26cd-44cb-9efe-4eafcd709bfe",
      "metadata": {
        "id": "3e83f5d2-26cd-44cb-9efe-4eafcd709bfe"
      },
      "outputs": [],
      "source": [
        "rir_recordings = RecordingSet()\n",
        "rir_raw_dir = download_rir_noise(data_dir)\n",
        "rirs = prepare_rir_noise(rir_raw_dir, parts=[\"sim_rir\"])\n",
        "rir_recordings = rirs[\"sim_rir\"][\"recordings\"]\n",
        "shuffled_rir_recordings = rir_recordings.shuffle()\n",
        "\n",
        "rir_val_part = int(len(rir_recordings) * 0.1)\n",
        "rir_train_part = len(rir_recordings) - rir_val_part\n",
        "\n",
        "rir_recordings_train = shuffled_rir_recordings.subset(last=rir_train_part)\n",
        "rir_recordings_val = shuffled_rir_recordings.subset(first=rir_val_part)\n",
        "\n",
        "rir_recordings_train.to_file(dataset_manifest['rir_train'])\n",
        "rir_recordings_val.to_file(dataset_manifest['rir_val'])"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "b568a03e-57a1-412f-a3a2-85af3a1e3a20",
      "metadata": {
        "id": "b568a03e-57a1-412f-a3a2-85af3a1e3a20"
      },
      "source": [
        "For this tutorial, a single-channel noisy validation set is constructed by adding speech and noise.\n",
        "\n",
        "The following block will use based on Lhotse data loader from NeMo to create fixed noisy validation set and save it do `data/val` folder."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "62dc5755-ccc5-4eb8-80d4-d59b418406a3",
      "metadata": {
        "id": "62dc5755-ccc5-4eb8-80d4-d59b418406a3"
      },
      "outputs": [],
      "source": [
        "# Create the cofing for the Lhotse data loader\n",
        "val_noise_config = {\n",
        "'cuts_path': dataset_manifest['speech_val'].as_posix(), # path to Lhotse cuts manifest with speech signals for augmentation\n",
        "'sample_rate': 16000,\n",
        "'batch_size': 1,\n",
        "'rir_enabled': True, # enable room impulse response augmentation\n",
        "'rir_path': dataset_manifest['rir_val'].as_posix(), # path to Lhotse recordings manifest with room impulse response signals\n",
        "'rir_prob': 1.0, # probability of applying RIR augmentation\n",
        "'noise_path': dataset_manifest['noise_val'].as_posix(), # path to Lhotse cuts manifest with noise signals\n",
        "'noise_mix_prob': 1.0,  # probability of applying noise augmentation\n",
        "'noise_snr':  (0, 20), # range of speech-to-noise ratio for the noise augmentation\n",
        "'shuffle': False\n",
        "}\n",
        "\n",
        "# Instantiate the data loader\n",
        "dl = get_lhotse_dataloader_from_config(\n",
        "OmegaConf.create(val_noise_config), global_rank=0, world_size=1, dataset=LhotseAudioToTargetDataset()\n",
        ")\n",
        "\n",
        "# Define number of samples for the validation set\n",
        "num_examples = 100\n",
        "print(f'Get {num_examples} samples for the validation set')\n",
        "samples = [sample for sample in islice(dl, num_examples)]\n",
        "\n",
        "\n",
        "# Create folders for saving noisy (input) and clean (target) samples\n",
        "val_dir = data_dir / 'val'\n",
        "val_noisy_dir = val_dir / 'noisy'\n",
        "val_clean_dir = val_dir / 'clean'\n",
        "\n",
        "val_dir.mkdir(exist_ok=True)\n",
        "val_noisy_dir.mkdir(exist_ok =True)\n",
        "val_clean_dir.mkdir(exist_ok=True)\n",
        "\n",
        "val_noisy_basename = 'val_noisy_fileid'\n",
        "val_clean_basename = 'val_clean_fileid'\n",
        "\n",
        "with CutSet.open_writer(dataset_manifest['noisy_val']) as writer:\n",
        "    for n, sample in enumerate(samples):\n",
        "        noisy, clean = sample['input_signal'].numpy()[0], sample['target_signal'].numpy()[0]\n",
        "        #Save\n",
        "        sf.write(val_noisy_dir / f'{val_noisy_basename}_{str(n)}.wav', noisy, samplerate=val_noise_config['sample_rate'])\n",
        "        sf.write(val_clean_dir / f'{val_clean_basename}_{str(n)}.wav', clean, samplerate=val_noise_config['sample_rate'])\n",
        "        noisy_rec = Recording.from_file(val_noisy_dir / f'{val_noisy_basename}_{str(n)}.wav')\n",
        "        clean_rec = Recording.from_file(val_clean_dir / f'{val_clean_basename}_{str(n)}.wav')\n",
        "\n",
        "        val_cut = MonoCut(id=noisy_rec.id,\n",
        "                start=0,\n",
        "                duration=noisy_rec.duration,\n",
        "                channel=0,\n",
        "                recording=noisy_rec)\n",
        "        val_cut.target_recording = clean_rec\n",
        "        writer.write(val_cut)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "f108c5d2-c87b-47a7-8c2d-d363a9234abb",
      "metadata": {
        "id": "f108c5d2-c87b-47a7-8c2d-d363a9234abb"
      },
      "source": [
        "### Model configuration\n",
        "\n",
        "Here, a simple encoder-mask-decoder model will be used to process the noisy input signal and produce an enhanced output signal.\n",
        "\n",
        "In general, an encoder-mask-decoder model can be configured using `EncMaskDecAudioToAudioModel` class, which is depicted in the following block diagram."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "c0ff2bff-1637-4a17-85b2-c92307a4d8d7",
      "metadata": {
        "id": "c0ff2bff-1637-4a17-85b2-c92307a4d8d7"
      },
      "source": [
        "<img src=\"https://github.com/NVIDIA/NeMo/releases/download/v1.18.0/encmaskdecoder_model.png\" alt=\"encmaskdecoder_model\" style=\"width: 800px;\"/>"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "470b0a50-fe6c-4ded-9d27-64db20f83cad",
      "metadata": {
        "id": "470b0a50-fe6c-4ded-9d27-64db20f83cad"
      },
      "source": [
        "The model structure can briefly be described as follows:\n",
        "* Input to the model is a time-domain signal.\n",
        "* Encoder transforms the input signal to the analysis domain.\n",
        "* Mask estimator estimates a mask used to generate the output signal.\n",
        "* Mask processor combines the estimated mask and the encoded input to produce the encoded output.\n",
        "* Decoder transforms the encoded output into a time-domain signal.\n",
        "* Output is a time-domain signal."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "4cf1cfde-913f-42da-a9e0-fdb5ae8e50c5",
      "metadata": {
        "id": "4cf1cfde-913f-42da-a9e0-fdb5ae8e50c5"
      },
      "source": [
        "For this example, the model will be configured to use a fixed short-time Fourier transform-based encoder and decoder, and the mask will be estimated using a recurrent neural network. The model used here is depicted in the following block diagram."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "8ed37321-2e8b-4d8e-90bf-efc0897517e5",
      "metadata": {
        "id": "8ed37321-2e8b-4d8e-90bf-efc0897517e5"
      },
      "source": [
        "<img src=\"https://github.com/NVIDIA/NeMo/releases/download/v1.18.0/single_output_example_model.png\" alt=\"single_output_example_model\" style=\"width: 1000px;\"/>"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "36596c01-db7d-402f-a5df-0fb9b4641b08",
      "metadata": {
        "id": "36596c01-db7d-402f-a5df-0fb9b4641b08"
      },
      "source": [
        "In this particular configuration, the model structure can be described as follows:\n",
        "* `AudioToSpectrogram` implements the analysis STFT transform.\n",
        "* `MaskEstimatorRNN` is a mask estimator using RNNs.\n",
        "* `MaskReferenceChannel` is a simple processor which applies the estimated mask on the reference channel. In this tutorial, the input signal has only a single channel, so the reference channel will be set to `0`.\n",
        "* `SpectrogramToAudio` implements the synthesis STFT transform."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "03546307-1c4e-4775-a754-276c9a69be5c",
      "metadata": {
        "id": "03546307-1c4e-4775-a754-276c9a69be5c"
      },
      "source": [
        "The following cell will load and show the default configuration for the model depicted above."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "03952acf-1792-4470-9120-dddf20a9f646",
      "metadata": {
        "scrolled": true,
        "id": "03952acf-1792-4470-9120-dddf20a9f646"
      },
      "outputs": [],
      "source": [
        "config_dir = root_dir / 'conf'\n",
        "config_dir.mkdir(exist_ok=True)\n",
        "\n",
        "config_path = config_dir / 'masking_with_online_augmentation.yaml'\n",
        "\n",
        "if not config_path.is_file():\n",
        "    !wget https://raw.githubusercontent.com/{GIT_USER}/NeMo/{BRANCH}/examples/audio/conf/masking_with_online_augmentation.yaml -P {config_dir.as_posix()}\n",
        "\n",
        "config = OmegaConf.load(config_path)\n",
        "config = OmegaConf.to_container(config, resolve=True)\n",
        "config = OmegaConf.create(config)\n",
        "\n",
        "print('Loaded config')\n",
        "print(OmegaConf.to_yaml(config))"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "1c872345-3bfd-4933-8c51-dbb62fce790a",
      "metadata": {
        "id": "1c872345-3bfd-4933-8c51-dbb62fce790a"
      },
      "source": [
        "Training dataset is configured with the following parameters\n",
        "* `cuts_path` points to a Lhotse manifest file, containing speech samples\n",
        "* `noise_path` poins to a Lhotse manifest file, containing noise samples\n",
        "* `noise_mix_prob` defines the probabilty with which noise will be added during training\n",
        "* `noise_snr` defines an SNR range for mixing noise samples\n",
        "* `rir_enabled` enables room impulse response agmentation\n",
        "* `rir_path` points to a Lhotse manifest file, containing RIR samples\n",
        "* `rir_prob` defines the probabilty with which RIR will be added during training\n",
        "  \n",
        "For the validation and test sets only `cuts_path` parameter is used since the `val` manifest already contains noisy and clean samples"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "e19c4209-b10b-4924-a3de-e792de57c313",
      "metadata": {
        "id": "e19c4209-b10b-4924-a3de-e792de57c313"
      },
      "outputs": [],
      "source": [
        "# Setup training dataset\n",
        "config.model.train_ds.cuts_path = dataset_manifest['speech_train'].as_posix() # path to Lhotse cuts manifest with speech signals for augmentation\n",
        "config.model.train_ds.noise_path =  dataset_manifest['noise_train'].as_posix() # path to Lhotse cuts manifest with noise signals\n",
        "config.model.train_ds.noise_mix_prob =  1.0 # probability of applying noise augmentation\n",
        "config.model.train_ds.noise_snr =  (0, 20) # range of speech-to-noise ratio for the noise augmentation\n",
        "config.model.train_ds.rir_enabled = True # enable room impulse response augmentation\n",
        "config.model.train_ds.rir_path =  dataset_manifest['rir_val'].as_posix() # path to Lhotse recordings manifest with room impulse response signals\n",
        "config.model.train_ds.rir_prob = 1.0 # probability of applying RIR augmentation\n",
        "\n",
        "config.model.validation_ds.cuts_path = dataset_manifest['noisy_val'].as_posix() # fixed noisy validation set\n",
        "\n",
        "config.model.test_ds.cuts_path = dataset_manifest['noisy_val'].as_posix() # fixed noisy test set\n",
        "\n",
        "\n",
        "print(\"Train dataset config:\")\n",
        "print(OmegaConf.to_yaml(config.model.train_ds))"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "534aef42-d3a7-4477-8f89-8fe780f2b5f7",
      "metadata": {
        "id": "534aef42-d3a7-4477-8f89-8fe780f2b5f7"
      },
      "source": [
        "Metrics for validation and test set are configured in the following cell.\n",
        "\n",
        "In this tutorial, signal-to-distortion ratio (SDR) and scale-invariant SDR from torch metrics are used [5]."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "522f1065-fe23-4101-b6a9-0706cc8db389",
      "metadata": {
        "id": "522f1065-fe23-4101-b6a9-0706cc8db389"
      },
      "outputs": [],
      "source": [
        "# Setup metrics to compute on validation and test sets\n",
        "metrics = OmegaConf.create({\n",
        "    'sisdr': {\n",
        "        '_target_': 'torchmetrics.audio.ScaleInvariantSignalDistortionRatio',\n",
        "    },\n",
        "    'sdr': {\n",
        "        '_target_': 'torchmetrics.audio.SignalDistortionRatio',\n",
        "    }\n",
        "})\n",
        "config.model.metrics.val = metrics\n",
        "config.model.metrics.test = metrics\n",
        "\n",
        "print(\"Metrics config:\")\n",
        "print(OmegaConf.to_yaml(config.model.metrics))"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "4771d02f-3b01-481c-8894-bb43644a941a",
      "metadata": {
        "id": "4771d02f-3b01-481c-8894-bb43644a941a"
      },
      "source": [
        "### Trainer configuration\n",
        "NeMo models are primarily PyTorch Lightning modules and therefore are entirely compatible with the PyTorch Lightning ecosystem."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "2208b407-07e2-4bb0-bfad-2ef488acf217",
      "metadata": {
        "id": "2208b407-07e2-4bb0-bfad-2ef488acf217"
      },
      "outputs": [],
      "source": [
        "print(\"Trainer config:\")\n",
        "print(OmegaConf.to_yaml(config.trainer))"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "f2dac35a-b487-4037-907d-326f321564ca",
      "metadata": {
        "id": "f2dac35a-b487-4037-907d-326f321564ca"
      },
      "source": [
        "We can modify some trainer configs for this tutorial.\n",
        "Most importantly, the number of epochs is set to a small value, to limit the runtime for the purpose of this example."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "2c18dbd2-ad22-4eae-912a-56d13aa74a6b",
      "metadata": {
        "id": "2c18dbd2-ad22-4eae-912a-56d13aa74a6b"
      },
      "outputs": [],
      "source": [
        "# Checks if we have GPU available and uses it\n",
        "accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'\n",
        "config.trainer.devices = 1\n",
        "config.trainer.accelerator = accelerator\n",
        "\n",
        "# Reduces maximum number of epochs for quick demonstration\n",
        "config.trainer.max_epochs = 30\n",
        "\n",
        "# Remove distributed training flags\n",
        "config.trainer.strategy = 'auto'\n",
        "\n",
        "# Instantiate the trainer\n",
        "trainer = pl.Trainer(**config.trainer)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "c1198508-7fb5-4ba7-a4e5-ad15153037a1",
      "metadata": {
        "id": "c1198508-7fb5-4ba7-a4e5-ad15153037a1"
      },
      "source": [
        "### Experiment manager\n",
        "\n",
        "NeMo has an experiment manager that handles logging and checkpointing."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "72642c8c-97f5-4b6c-9e63-030c7a07810c",
      "metadata": {
        "scrolled": true,
        "id": "72642c8c-97f5-4b6c-9e63-030c7a07810c"
      },
      "outputs": [],
      "source": [
        "from nemo.utils.exp_manager import exp_manager\n",
        "\n",
        "exp_dir = exp_manager(trainer, config.get(\"exp_manager\", None))\n",
        "# The exp_dir provides a path to the current experiment for easy access\n",
        "\n",
        "print(\"Experiment directory:\")\n",
        "print(exp_dir)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "00562a21-0c0f-4c53-b89e-344431e75a42",
      "metadata": {
        "id": "00562a21-0c0f-4c53-b89e-344431e75a42"
      },
      "source": [
        "### Model instantiation"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "0b73c3de-6ded-491d-8026-8a093f7ffc16",
      "metadata": {
        "scrolled": true,
        "id": "0b73c3de-6ded-491d-8026-8a093f7ffc16"
      },
      "outputs": [],
      "source": [
        "from nemo.collections import audio as nemo_audio\n",
        "\n",
        "enhancement_model = nemo_audio.models.EncMaskDecAudioToAudioModel(cfg=config.model, trainer=trainer)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "a8f4373e-6aca-41d3-a69d-5bef53a88f20",
      "metadata": {
        "id": "a8f4373e-6aca-41d3-a69d-5bef53a88f20"
      },
      "source": [
        "### Training\n",
        "Create a Tensorboard visualization to monitor progress"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "a253ce7c-6cca-441f-818e-1c8a2f00a9f5",
      "metadata": {
        "id": "a253ce7c-6cca-441f-818e-1c8a2f00a9f5"
      },
      "outputs": [],
      "source": [
        "try:\n",
        "    from google import colab\n",
        "    COLAB_ENV = True\n",
        "except (ImportError, ModuleNotFoundError):\n",
        "    COLAB_ENV = False\n",
        "\n",
        "# Load the TensorBoard notebook extension\n",
        "if COLAB_ENV:\n",
        "    %load_ext tensorboard\n",
        "    %tensorboard --logdir {exp_dir}\n",
        "else:\n",
        "    print(\"To use tensorboard, please use this notebook in a Google Colab environment.\")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "9183a35d-8332-49f6-96b3-3d14dd70ff4e",
      "metadata": {
        "id": "9183a35d-8332-49f6-96b3-3d14dd70ff4e"
      },
      "source": [
        "Training can be started using `trainer.fit`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "0eab4e7e-3004-49fa-a09b-10dfd12ba294",
      "metadata": {
        "scrolled": true,
        "id": "0eab4e7e-3004-49fa-a09b-10dfd12ba294"
      },
      "outputs": [],
      "source": [
        "trainer.fit(enhancement_model)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "af906f16-626c-4ca4-a465-fb0cca8d514a",
      "metadata": {
        "id": "af906f16-626c-4ca4-a465-fb0cca8d514a"
      },
      "source": [
        "After the training is completed, the configured metrics can be easily computed on the test set as follows:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "589af425-18b2-4445-8c88-d3161e9928c5",
      "metadata": {
        "id": "589af425-18b2-4445-8c88-d3161e9928c5"
      },
      "outputs": [],
      "source": [
        "trainer.test(enhancement_model, ckpt_path=None)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "cce96446-a969-4c6f-a865-968ee61119a8",
      "metadata": {
        "id": "cce96446-a969-4c6f-a865-968ee61119a8"
      },
      "source": [
        "### Inference\n",
        "\n",
        "The following cell provides an example of inference on an single audio file.\n",
        "For simplicity, the audio file information is taken from the test dataset."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "9f29ade5-424e-4a7a-81b7-dc3c5cc90bf7",
      "metadata": {
        "id": "9f29ade5-424e-4a7a-81b7-dc3c5cc90bf7"
      },
      "outputs": [],
      "source": [
        "# Load 10 samples from test_dataloader\n",
        "samples = [sample for sample in islice(enhancement_model.test_dataloader(), 10)]\n",
        "\n",
        "# Different sample can be used via list index\n",
        "sample = samples[0]\n",
        "\n",
        "noisy_tensor = sample['input_signal']\n",
        "speech_tensor = sample['target_signal']\n",
        "\n",
        "# Get the one-dimentional numpy signals for the plotting audio files and metrics calculation\n",
        "noisy_signal = noisy_tensor.squeeze(0).numpy()\n",
        "speech_signal = speech_tensor.squeeze(0).numpy()\n",
        "\n",
        "\n",
        "# Move to device\n",
        "device = 'cuda' if accelerator == 'gpu' else 'cpu'\n",
        "enhancement_model = enhancement_model.to(device)\n",
        "\n",
        "# Process using the model\n",
        "with torch.no_grad():\n",
        "    output_tensor, _ = enhancement_model(input_signal=noisy_tensor.unsqueeze(1).cuda())\n",
        "output_signal = output_tensor[0][0].detach().cpu().numpy()"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "5110ee6c-3939-4d70-a4d6-6422de31da51",
      "metadata": {
        "id": "5110ee6c-3939-4d70-a4d6-6422de31da51"
      },
      "source": [
        "Signals can be easily plotted and signal metrics can be calculated for the given example."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "0275a032-969c-4e31-b8b5-8a61cb1f82ee",
      "metadata": {
        "id": "0275a032-969c-4e31-b8b5-8a61cb1f82ee"
      },
      "outputs": [],
      "source": [
        "# Show noisy and clean signals\n",
        "show_metrics(signal=noisy_signal, reference=speech_signal, tag='Noisy signal', sample_rate=sample_rate)\n",
        "show_metrics(signal=output_signal, reference=speech_signal, tag='Output signal', sample_rate=sample_rate)\n",
        "\n",
        "# Show signals\n",
        "show_signal(speech_signal, tag='Speech signal')\n",
        "show_signal(noisy_signal, tag='Noisy signal')\n",
        "show_signal(output_signal, tag='Output signal')\n",
        "\n",
        "# Play audio\n",
        "print('Speech signal')\n",
        "ipd.display(ipd.Audio(speech_signal, rate=sample_rate))\n",
        "\n",
        "print('Noisy signal')\n",
        "ipd.display(ipd.Audio(noisy_signal, rate=sample_rate))\n",
        "\n",
        "print('Output signal')\n",
        "ipd.display(ipd.Audio(output_signal, rate=sample_rate))"
      ]
    },
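    {
      "cell_type": "markdown",
      "id": "sisdr-sketch-markdown-cell",
      "metadata": {
        "id": "sisdr-sketch-markdown-cell"
      },
      "source": [
        "As a rough illustration of what a signal metric measures, the following cell sketches the scale-invariant signal-to-distortion ratio (SI-SDR) in plain NumPy.\n",
        "This is a minimal sketch and not necessarily how `show_metrics` computes its values: the function name `si_sdr_db`, the `eps` regularization and the synthetic test tone are assumptions made for this example.\n",
        "The same function could be applied to `output_signal` and `speech_signal` from the inference cell above."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "sisdr-sketch-code-cell",
      "metadata": {
        "id": "sisdr-sketch-code-cell"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "\n",
        "def si_sdr_db(estimate, reference, eps=1e-8):\n",
        "    # Hypothetical helper for illustration: SI-SDR in dB for one-dimensional signals\n",
        "    estimate = np.asarray(estimate, dtype=np.float64)\n",
        "    reference = np.asarray(reference, dtype=np.float64)\n",
        "\n",
        "    # Remove the mean so that a DC offset does not affect the result\n",
        "    estimate = estimate - estimate.mean()\n",
        "    reference = reference - reference.mean()\n",
        "\n",
        "    # Project the estimate onto the reference to obtain the scaled target component\n",
        "    scale = np.dot(estimate, reference) / (np.dot(reference, reference) + eps)\n",
        "    target = scale * reference\n",
        "    residual = estimate - target\n",
        "\n",
        "    # Ratio of target energy to residual energy, in dB\n",
        "    return 10 * np.log10((np.sum(target**2) + eps) / (np.sum(residual**2) + eps))\n",
        "\n",
        "\n",
        "# Toy data (assumption): a 440 Hz tone sampled at 16 kHz with additive Gaussian noise\n",
        "rng = np.random.default_rng(0)\n",
        "toy_reference = np.sin(2 * np.pi * 440 * np.arange(16000) / 16000)\n",
        "toy_noisy = toy_reference + 0.1 * rng.standard_normal(16000)\n",
        "\n",
        "# Rescaling the estimate should leave the metric approximately unchanged\n",
        "print(f'SI-SDR of the toy noisy signal: {si_sdr_db(toy_noisy, toy_reference):.2f} dB')\n",
        "print(f'SI-SDR after scaling by 0.5:    {si_sdr_db(0.5 * toy_noisy, toy_reference):.2f} dB')"
      ]
    },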
    {
      "cell_type": "markdown",
      "id": "89eb3478-bf82-4fbd-aa7d-0a55edf57f3a",
      "metadata": {
        "id": "89eb3478-bf82-4fbd-aa7d-0a55edf57f3a"
      },
      "source": [
        "## Next steps\n",
        "This is a simple tutorial which can serve as a starting point for prototyping and experimentation with audio-to-audio models.\n",
        "A processed audio output can be used, for example, for ASR or TTS.\n",
        "\n",
        "For more details about NeMo models and applications in in ASR and TTS, we recommend you checkout other tutorials next:\n",
        "\n",
        "* [NeMo fundamentals](https://colab.research.google.com/github/NVIDIA/NeMo/blob/stable/tutorials/00_NeMo_Primer.ipynb)\n",
        "* [NeMo models](https://colab.research.google.com/github/NVIDIA/NeMo/blob/stable/tutorials/01_NeMo_Models.ipynb)\n",
        "* [Speech Recognition](https://colab.research.google.com/github/NVIDIA/NeMo/blob/stable/tutorials/asr/ASR_with_NeMo.ipynb)\n",
        "* [Speech Synthesis](https://colab.research.google.com/github/NVIDIA/NeMo/blob/stable/tutorials/tts/Inference_ModelSelect.ipynb)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "31a162ae-6b9e-40e4-a468-fbbad3360b11",
      "metadata": {
        "id": "31a162ae-6b9e-40e4-a468-fbbad3360b11"
      },
      "source": [
        "## References\n",
        "\n",
        "[1] Żelasko, P., Povey, D., Trmal, J., & Khudanpur, S. (2021). Lhotse: a speech data representation library for the modern deep learning ecosystem. https://arxiv.org/abs/2110.12561\n",
        "\n",
        "[2] V. Panayotov, G. Chen, D. Povery, S. Khudanpur, \"LibriSpeech: An ASR corpus based on public domain audio books,\" ICASSP 2015\n",
        "\n",
        "[3] Lhotse documentation, https://lhotse.readthedocs.io/\n",
        "\n",
        "[4] J. Thieman, N. Ito, V. Emmanuel, \"DEMAND: collection of multi-channel recordings of acoustic noise in diverse environments,\" ICA 2013\n",
        "\n",
        "[5] T. Ko, V. Peddinti, D. Povey, M. L. Seltzer and S. Khudanpur, \"A study on data augmentation of reverberant speech for robust speech recognition,\" 2017 ICASSP\n",
        "\n",
        "[6] https://github.com/Lightning-AI/torchmetrics"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.12"
    },
    "colab": {
      "provenance": []
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}
